import torch
from einops import rearrange
from enformer_pytorch import Enformer
from torch import nn


class EnformerPretrainedModel(nn.Module):
    """
    Pretrained Enformer trunk followed by a linear task head and mean pooling.

    The stem, convolutional tower, the first ``depth`` transformer layers and
    the final pointwise block are reused from the
    ``EleutherAI/enformer-official-rough`` checkpoint. A newly created linear
    layer maps every sequence position to ``n_tasks`` values, which are then
    averaged over the sequence axis.

    Args:
        n_tasks: Number of outputs produced per example.
        depth: How many leading transformer layers of the checkpoint to keep.
    """

    def __init__(self, n_tasks, depth=11):
        super().__init__()
        self.n_tasks = n_tasks

        backbone = Enformer.from_pretrained(
            "EleutherAI/enformer-official-rough", target_length=-1
        )
        self.stem = backbone.stem
        self.conv_tower = backbone.conv_tower
        # Truncate the transformer stack to its first `depth` layers.
        self.transformer = backbone.transformer[:depth]
        self.final_pointwise = backbone.final_pointwise
        # 3072 must equal the channel width emitted by final_pointwise for this
        # checkpoint (presumably 2 * the model dim of 1536 -- confirm).
        self.linear = nn.Linear(in_features=3072, out_features=n_tasks, bias=True)
        self.pool = nn.AdaptiveAvgPool1d(1)

    def forward(self, x):
        """
        Run the trunk and return one pooled prediction per task.

        Args:
            x: Input fed straight into the stem. Presumably a channels-first
                one-hot tensor of shape (batch, 4, seq_len) -- TODO confirm.

        Returns:
            Tensor of shape (batch, n_tasks): the linear head's outputs
            averaged over all sequence positions.
        """
        conv_features = self.conv_tower(self.stem(x))
        # The conv layers keep channels on axis 1; the transformer and the
        # linear head act on the last axis, so move channels there.
        tokens = conv_features.transpose(1, 2)
        hidden = self.final_pointwise(self.transformer(tokens))
        per_position = self.linear(hidden)
        # Put tasks on axis 1 so the 1-D pool averages over positions.
        pooled = self.pool(per_position.transpose(1, 2))
        return pooled.squeeze(dim=-1)


class EnformerModel(nn.Module):
    """
    Enformer trunk built from hyperparameters, followed by a linear task head
    and mean pooling.

    The stem, convolutional tower, full transformer stack and final pointwise
    block come from ``Enformer.from_hparams(**kwargs)``; no checkpoint is
    loaded here. A linear layer maps every sequence position to ``n_tasks``
    values, which are then averaged over the sequence axis.

    Args:
        n_tasks: Number of outputs produced per example.
        **kwargs: Hyperparameters forwarded unchanged to
            ``Enformer.from_hparams`` (e.g. ``dim``, ``depth``).
    """

    def __init__(self, n_tasks, **kwargs):
        super().__init__()
        self.n_tasks = n_tasks

        model = Enformer.from_hparams(**kwargs)

        self.stem = model.stem
        self.conv_tower = model.conv_tower
        self.transformer = model.transformer
        self.final_pointwise = model.final_pointwise
        # `dim` used to be an undefined name here (NameError on every
        # construction). Read the width from the built config so it reflects
        # both an explicit `dim=` kwarg and the library default.
        dim = model.config.dim
        # final_pointwise presumably emits 2 * dim channels (cf. the 3072 used
        # for the 1536-dim pretrained checkpoint) -- confirm on library upgrades.
        self.linear = nn.Linear(dim * 2, n_tasks, bias=True)
        self.pool = nn.AdaptiveAvgPool1d(1)

    def forward(self, x):
        """
        Run the trunk and return one pooled prediction per task.

        Args:
            x: Input fed straight into the stem. Presumably a channels-first
                one-hot tensor of shape (batch, 4, seq_len) -- TODO confirm.

        Returns:
            Tensor of shape (batch, n_tasks): the linear head's outputs
            averaged over all sequence positions.
        """
        x = self.stem(x)
        x = self.conv_tower(x)
        # Move channels to the last axis for the transformer / linear head.
        x = torch.swapaxes(x, 1, 2)
        x = self.transformer(x)
        x = self.final_pointwise(x)
        x = self.linear(x)
        # Tasks back on axis 1 so the 1-D pool averages over positions.
        x = torch.swapaxes(x, 1, 2)
        x = self.pool(x)
        return x.squeeze(-1)
